{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "BERT_base_imdb.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "7_FXneEfpdMM",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 347
        },
        "outputId": "cb867607-347f-437c-e740-850fb7e4e288"
      },
      "source": [
        "!pip install pytorch_pretrained_bert pytorch-nlp"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: pytorch_pretrained_bert in /usr/local/lib/python3.6/dist-packages (0.6.2)\n",
            "Requirement already satisfied: pytorch-nlp in /usr/local/lib/python3.6/dist-packages (0.5.0)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from pytorch_pretrained_bert) (1.18.5)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from pytorch_pretrained_bert) (2.23.0)\n",
            "Requirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from pytorch_pretrained_bert) (1.14.24)\n",
            "Requirement already satisfied: regex in /usr/local/lib/python3.6/dist-packages (from pytorch_pretrained_bert) (2019.12.20)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from pytorch_pretrained_bert) (4.41.1)\n",
            "Requirement already satisfied: torch>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from pytorch_pretrained_bert) (1.5.1+cu101)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch_pretrained_bert) (3.0.4)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch_pretrained_bert) (1.24.3)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch_pretrained_bert) (2020.6.20)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch_pretrained_bert) (2.10)\n",
            "Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch_pretrained_bert) (0.10.0)\n",
            "Requirement already satisfied: s3transfer<0.4.0,>=0.3.0 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch_pretrained_bert) (0.3.3)\n",
            "Requirement already satisfied: botocore<1.18.0,>=1.17.24 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch_pretrained_bert) (1.17.24)\n",
            "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch>=0.4.1->pytorch_pretrained_bert) (0.16.0)\n",
            "Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /usr/local/lib/python3.6/dist-packages (from botocore<1.18.0,>=1.17.24->boto3->pytorch_pretrained_bert) (2.8.1)\n",
            "Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.18.0,>=1.17.24->boto3->pytorch_pretrained_bert) (0.15.2)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.18.0,>=1.17.24->boto3->pytorch_pretrained_bert) (1.15.0)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "EnVIV6Vt8f4d",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "042018b4-0c4e-4a5b-9a27-be901e4f1e5e"
      },
      "source": [
        "import sys\n",
        "import numpy as np\n",
        "import random as rn\n",
        "import torch\n",
        "from pytorch_pretrained_bert import BertModel\n",
        "from torch import nn\n",
        "from torchnlp.datasets import imdb_dataset\n",
        "from pytorch_pretrained_bert import BertTokenizer\n",
        "from keras.preprocessing.sequence import pad_sequences\n",
        "from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler\n",
        "from torch.optim import Adam\n",
        "from torch.nn.utils import clip_grad_norm_\n",
        "from IPython.display import clear_output"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Using TensorFlow backend.\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cUYrv06z8gaF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "rn.seed(321)\n",
        "np.random.seed(321)\n",
        "torch.manual_seed(321)\n",
        "torch.cuda.manual_seed(321)"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rgkbhHcB17GY",
        "colab_type": "text"
      },
      "source": [
        "## Prepare the Data"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ur8i7boP6qtb",
        "colab": {}
      },
      "source": [
        "train_data, test_data = imdb_dataset(train=True, test=True)\n",
        "rn.shuffle(train_data)\n",
        "rn.shuffle(test_data)\n",
        "train_data = train_data[:1000]\n",
        "test_data = test_data[:100]"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "l1VENo6tqG7J",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "821f3ace-ecd8-44dd-9e5a-4fbb3c19ac91"
      },
      "source": [
        "train_texts, train_labels = list(zip(*map(lambda d: (d['text'], d['sentiment']), train_data)))\n",
        "test_texts, test_labels = list(zip(*map(lambda d: (d['text'], d['sentiment']), test_data)))\n",
        "\n",
        "len(train_texts), len(train_labels), len(test_texts), len(test_labels)"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(1000, 1000, 100, 100)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ty24UrRjqIsb",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "26trq3gIrJeG",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "fddf084c-d510-4934-d188-b90837e69bb9"
      },
      "source": [
        "tokenizer.tokenize('Hi my name is Dima')"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['hi', 'my', 'name', 'is', 'dim', '##a']"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 7
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1k9rcOzQr5Zm",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "a39ce5a6-5d60-4197-a33e-202884008fee"
      },
      "source": [
        "train_tokens = list(map(lambda t: ['[CLS]'] + tokenizer.tokenize(t)[:510] + ['[SEP]'], train_texts))\n",
        "test_tokens = list(map(lambda t: ['[CLS]'] + tokenizer.tokenize(t)[:510] + ['[SEP]'], test_texts))\n",
        "\n",
        "len(train_tokens), len(test_tokens)                   \n",
        "                   "
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(1000, 100)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 8
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9Ca7KKnhuT5c",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "67c6e162-415a-4a67-af76-9d5aa02541dc"
      },
      "source": [
        "train_tokens_ids = pad_sequences(list(map(tokenizer.convert_tokens_to_ids, train_tokens)), maxlen=512, truncating=\"post\", padding=\"post\", dtype=\"int\")\n",
        "test_tokens_ids = pad_sequences(list(map(tokenizer.convert_tokens_to_ids, test_tokens)), maxlen=512, truncating=\"post\", padding=\"post\", dtype=\"int\")\n",
        "\n",
        "train_tokens_ids.shape, test_tokens_ids.shape"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "((1000, 512), (100, 512))"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "F7POtHuIOV-6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "0cfd64f9-e03e-44ef-db4d-c764b9572074"
      },
      "source": [
        "train_y = np.array(train_labels) == 'pos'\n",
        "test_y = np.array(test_labels) == 'pos'\n",
        "train_y.shape, test_y.shape, np.mean(train_y), np.mean(test_y)"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "((1000,), (100,), 0.489, 0.5)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "U-xXMEqXOWTE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "train_masks = [[float(i > 0) for i in ii] for ii in train_tokens_ids]\n",
        "test_masks = [[float(i > 0) for i in ii] for ii in test_tokens_ids]"
      ],
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2K4JQMFo1-_S",
        "colab_type": "text"
      },
      "source": [
        "# Baseline"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wdzjl_WlwpKr",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from sklearn.feature_extraction.text import CountVectorizer\n",
        "from sklearn.linear_model import LogisticRegression\n",
        "from sklearn.pipeline import make_pipeline\n",
        "from sklearn.metrics import classification_report"
      ],
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9jyb-hJ0xAgG",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 156
        },
        "outputId": "198c2afb-e99b-477e-e105-2686c916d2a4"
      },
      "source": [
        "baseline_model = make_pipeline(CountVectorizer(ngram_range=(1,3)), LogisticRegression()).fit(train_texts, train_labels)"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
            "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
            "\n",
            "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
            "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
            "Please also refer to the documentation for alternative solver options:\n",
            "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
            "  extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "q9IzjAX_2VLf",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "baseline_predicted = baseline_model.predict(test_texts)"
      ],
      "execution_count": 14,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QnsCRIaQ3GPQ",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 173
        },
        "outputId": "3436b3b5-633d-4d7e-8150-cbca4e85f6bd"
      },
      "source": [
        "print(classification_report(test_labels, baseline_predicted))"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "              precision    recall  f1-score   support\n",
            "\n",
            "         neg       0.80      0.82      0.81        50\n",
            "         pos       0.82      0.80      0.81        50\n",
            "\n",
            "    accuracy                           0.81       100\n",
            "   macro avg       0.81      0.81      0.81       100\n",
            "weighted avg       0.81      0.81      0.81       100\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r_hEhebQ3YqI",
        "colab_type": "text"
      },
      "source": [
        "# BERT Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "E234ByBa3Qtb",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class BertBinaryClassifier(nn.Module):\n",
        "    def __init__(self, dropout=0.1):\n",
        "        super(BertBinaryClassifier, self).__init__()\n",
        "\n",
        "        self.bert = BertModel.from_pretrained('bert-base-uncased')\n",
        "\n",
        "        self.dropout = nn.Dropout(dropout)\n",
        "        self.linear = nn.Linear(768, 1)\n",
        "        self.sigmoid = nn.Sigmoid()\n",
        "    \n",
        "    def forward(self, tokens, masks=None):\n",
        "        _, pooled_output = self.bert(tokens, attention_mask=masks, output_all_encoded_layers=False)\n",
        "        dropout_output = self.dropout(pooled_output)\n",
        "        linear_output = self.linear(dropout_output)\n",
        "        proba = self.sigmoid(linear_output)\n",
        "        return proba\n",
        "        "
      ],
      "execution_count": 16,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ED9SE1Ka8W9x",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "dfef532d-1e09-4c65-e0e2-dddbb13a4a7f"
      },
      "source": [
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "device"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "device(type='cuda')"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 17
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FNZ_3auqDbjl",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 36
        },
        "outputId": "dffff4cc-014c-41cc-e9bc-5c44781c70f7"
      },
      "source": [
        "str(torch.cuda.memory_allocated(device)/1000000 ) + 'M'"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'0.0M'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 18
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Sf9n8zouENRi",
        "colab": {}
      },
      "source": [
        "bert_clf = BertBinaryClassifier()\n",
        "bert_clf = bert_clf.cuda()\n"
      ],
      "execution_count": 19,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LbHkaJuZEkYr",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 36
        },
        "outputId": "3f4e53e4-a1dc-472e-ab09-d9a649542c40"
      },
      "source": [
        "str(torch.cuda.memory_allocated(device)/1000000 ) + 'M'"
      ],
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'439.065088M'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 20
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LOQ-870M7VWy",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "d6a6ff95-6666-4701-bc1b-c2d3b03f982d"
      },
      "source": [
        "x = torch.tensor(train_tokens_ids[:3]).to(device)\n",
        "y, pooled = bert_clf.bert(x, output_all_encoded_layers=False)\n",
        "x.shape, y.shape, pooled.shape"
      ],
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(torch.Size([3, 512]), torch.Size([3, 512, 768]), torch.Size([3, 768]))"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 21
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LCb_pK4X7hb9",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 69
        },
        "outputId": "01e0a6cf-8a1e-42fd-d7a7-f172cc19b6ee"
      },
      "source": [
        "y = bert_clf(x)\n",
        "y.cpu().detach().numpy()"
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([[0.42505887],\n",
              "       [0.46727142],\n",
              "       [0.42605487]], dtype=float32)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 22
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MUm-gFCuFkoI",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 36
        },
        "outputId": "d664dfec-a8c2-4221-dfe1-3334727d771c"
      },
      "source": [
        "str(torch.cuda.memory_allocated(device)/1000000 ) + 'M'"
      ],
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'6697.349632M'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 23
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KzzsUZOUFcxp",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 36
        },
        "outputId": "e140bb9e-cb21-43d5-efc8-85996a283e21"
      },
      "source": [
        "y, x, pooled = None, None, None\n",
        "torch.cuda.empty_cache()\n",
        "str(torch.cuda.memory_allocated(device)/1000000 ) + 'M'"
      ],
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'439.065088M'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 24
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c9LPIYcn99r8",
        "colab_type": "text"
      },
      "source": [
        "# Fine-tune BERT"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZUkXhM1k_TAl",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "BATCH_SIZE = 8\n",
        "EPOCHS = 10"
      ],
      "execution_count": 25,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jGwV0yqg_o2u",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 36
        },
        "outputId": "d20bef39-d694-48fc-b9e9-258a845563e9"
      },
      "source": [
        "train_tokens_tensor = torch.tensor(train_tokens_ids)\n",
        "train_y_tensor = torch.tensor(train_y.reshape(-1, 1)).float()\n",
        "\n",
        "test_tokens_tensor = torch.tensor(test_tokens_ids)\n",
        "test_y_tensor = torch.tensor(test_y.reshape(-1, 1)).float()\n",
        "\n",
        "train_masks_tensor = torch.tensor(train_masks)\n",
        "test_masks_tensor = torch.tensor(test_masks)\n",
        "\n",
        "str(torch.cuda.memory_allocated(device)/1000000 ) + 'M'"
      ],
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'439.065088M'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 26
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2Yl2JpCe9YAu",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "train_dataset = TensorDataset(train_tokens_tensor, train_masks_tensor, train_y_tensor)\n",
        "train_sampler = RandomSampler(train_dataset)\n",
        "train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=BATCH_SIZE)\n",
        "\n",
        "test_dataset = TensorDataset(test_tokens_tensor, test_masks_tensor, test_y_tensor)\n",
        "test_sampler = SequentialSampler(test_dataset)\n",
        "test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=BATCH_SIZE)\n"
      ],
      "execution_count": 27,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JF_QD0naS8EQ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "param_optimizer = list(bert_clf.sigmoid.named_parameters()) \n",
        "optimizer_grouped_parameters = [{\"params\": [p for n, p in param_optimizer]}]"
      ],
      "execution_count": 28,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "b28PcoDh_cyd",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "optimizer = Adam(bert_clf.parameters(), lr=3e-6)"
      ],
      "execution_count": 29,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Z6yIChYvBP5F",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        " torch.cuda.empty_cache()"
      ],
      "execution_count": 30,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mqh8tCl4AFjo",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 52
        },
        "outputId": "cda43fda-c682-481c-a34b-0f9c2bb337e9"
      },
      "source": [
        "for epoch_num in range(EPOCHS):\n",
        "    bert_clf.train()\n",
        "    train_loss = 0\n",
        "    for step_num, batch_data in enumerate(train_dataloader):\n",
        "        token_ids, masks, labels = tuple(t.to(device) for t in batch_data)\n",
        "        print(str(torch.cuda.memory_allocated(device)/1000000 ) + 'M')\n",
        "        logits = bert_clf(token_ids, masks)\n",
        "        \n",
        "        loss_func = nn.BCELoss()\n",
        "\n",
        "        batch_loss = loss_func(logits, labels)\n",
        "        train_loss += batch_loss.item()\n",
        "        \n",
        "        \n",
        "        bert_clf.zero_grad()\n",
        "        batch_loss.backward()\n",
        "        \n",
        "\n",
        "        clip_grad_norm_(parameters=bert_clf.parameters(), max_norm=1.0)\n",
        "        optimizer.step()\n",
        "        \n",
        "        clear_output(wait=True)\n",
        "        print('Epoch: ', epoch_num + 1)\n",
        "        print(\"\\r\" + \"{0}/{1} loss: {2} \".format(step_num, len(train_data) / BATCH_SIZE, train_loss / (step_num + 1)))\n",
        "        "
      ],
      "execution_count": 31,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch:  10\n",
            "\r124/125.0 loss: 0.05438226660899818 \n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GHFWhkRYHv5l",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "bert_clf.eval()\n",
        "bert_predicted = []\n",
        "all_logits = []\n",
        "with torch.no_grad():\n",
        "    for step_num, batch_data in enumerate(test_dataloader):\n",
        "\n",
        "        token_ids, masks, labels = tuple(t.to(device) for t in batch_data)\n",
        "\n",
        "        logits = bert_clf(token_ids, masks)\n",
        "        loss_func = nn.BCELoss()\n",
        "        loss = loss_func(logits, labels)\n",
        "        numpy_logits = logits.cpu().detach().numpy()\n",
        "        \n",
        "        bert_predicted += list(numpy_logits[:, 0] > 0.5)\n",
        "        all_logits += list(numpy_logits[:, 0])\n",
        "    "
      ],
      "execution_count": 32,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vg_sX9BjooL-",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "877c47de-7463-4344-f321-0d4f0d3dd14b"
      },
      "source": [
        "np.mean(bert_predicted)"
      ],
      "execution_count": 33,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.57"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 33
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-DmIJqUnkVM8",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 173
        },
        "outputId": "b2206b2e-07aa-4d4a-8d2a-7ecffc9d097c"
      },
      "source": [
        "print(classification_report(test_y, bert_predicted))"
      ],
      "execution_count": 34,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "              precision    recall  f1-score   support\n",
            "\n",
            "       False       0.93      0.80      0.86        50\n",
            "        True       0.82      0.94      0.88        50\n",
            "\n",
            "    accuracy                           0.87       100\n",
            "   macro avg       0.88      0.87      0.87       100\n",
            "weighted avg       0.88      0.87      0.87       100\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eIBvoExLpOne",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 34,
      "outputs": []
    }
  ]
}